/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include "acl/acl_rt.h"
#include "acl/acl.h"
#include "aclnn_normalize_custom.h"
#include "../../../../../common/data_utils.h"

// ---- Test-case configuration -------------------------------------------------
// Positions below refer to main()'s tensor table:
//   0 srcGm, 1 inMeanGm, 2 inVarGm, 3 inGammaGm, 4 inBetaGm, 5 outGm, 6 outRstdGm.
constexpr uint8_t SRC_SIZE{5};         // leading table entries that are inputs loaded from ../input
constexpr uint16_t TIMEOUT{500};       // value handed to aclrtSynchronizeStreamWithTimeout (ms per ACL API)
constexpr uint8_t INDEX_IN_VAR{2};     // table position of inVarGm
constexpr uint8_t INDEX_OUT{5};        // table position of outGm
constexpr uint8_t INDEX_OUT_RSTD{6};   // table position of outRstdGm
constexpr uint8_t RN_SIZE{1};          // NOTE(review): not referenced by main() — confirm whether still needed

// Shape parameters: srcGm/outGm hold A_SIZE * R_SIZE_WITH_PAD elements,
// inMeanGm/inVarGm/outRstdGm hold A_SIZE, inGammaGm/inBetaGm hold R_SIZE_WITH_PAD.
constexpr uint32_t A_SIZE{8};
constexpr uint32_t R_SIZE{64};         // NOTE(review): unpadded length, not referenced by main()
constexpr uint32_t R_SIZE_WITH_PAD{64};

/**
 * Initializes ACL, selects @p device and creates a stream on it.
 *
 * @param device  device id passed to aclrtSetDevice.
 * @return the new stream, or NULL on failure. On failure, every step that had
 *         already succeeded is rolled back (device reset, ACL finalized), so the
 *         caller has nothing to clean up and must not call DestroyStream.
 */
aclrtStream CreateStream(int device)
{
    if (aclInit(nullptr) != ACL_SUCCESS) {
        printf("acl init failed\n");
        return nullptr;
    }
    if (aclrtSetDevice(device) != ACL_SUCCESS) {
        printf("Set device failed\n");
        (void)aclFinalize();
        return nullptr;
    }
    aclrtStream stream = nullptr;
    if (aclrtCreateStream(&stream) != ACL_SUCCESS) {
        printf("Create stream failed\n");
        // Undo aclrtSetDevice and aclInit; otherwise a failed setup leaves the
        // device context and the ACL runtime initialized.
        (void)aclrtResetDevice(device);
        (void)aclFinalize();
        return nullptr;
    }
    return stream;
}

/**
 * Tears down what CreateStream set up, in reverse order: destroys @p stream,
 * resets @p device, then finalizes ACL.
 *
 * Each step runs even if an earlier one failed. Reset and finalize failures are
 * printed; the result of aclrtDestroyStream is ignored.
 */
void DestroyStream(aclrtStream stream, int device)
{
    (void)aclrtDestroyStream(stream);

    const bool deviceReset = (aclrtResetDevice(device) == ACL_SUCCESS);
    if (!deviceReset) {
        printf("Reset device failed\n");
    }

    const bool aclFinalized = (aclFinalize() == ACL_SUCCESS);
    if (!aclFinalized) {
        printf("Finalize acl failed\n");
    }
}

/**
 * Host-side description of one operator argument; main() turns each entry into an
 * aclTensor via aclCreateTensor.
 * A NULL dims pointer marks an absent tensor: GetDataSize() returns 0 for it and
 * main() then passes NULL for that argument.
 */
struct tensorInfo {
    int64_t *dims;      // shape array (not owned by this struct); NULL => tensor not provided
    int64_t dimCnt;     // number of entries in dims
    aclDataType dtype;  // element type passed to aclCreateTensor
    aclFormat fmt;      // storage format passed to aclCreateTensor
};

/**
 * Returns the byte size of the buffer described by @p desc.
 *
 * The element count is the product of desc->dims[0 .. dimCnt). Every element is
 * counted as sizeof(float) bytes regardless of desc->dtype; every tensor in
 * main() is ACL_FLOAT. Returns 0 when desc->dims is NULL (absent tensor).
 */
int64_t GetDataSize(struct tensorInfo *desc)
{
    const int64_t *shape = desc->dims;
    if (shape == nullptr) {
        return 0;
    }

    int64_t elemCount = 1;
    for (int64_t axis = 0; axis < desc->dimCnt; ++axis) {
        elemCount *= shape[axis];
    }
    return elemCount * sizeof(float);
}

/**
 * Compares a host output buffer against ../output/golden_<goldenName>.bin.
 *
 * Both buffers are interpreted as float. Values that compare equal (including
 * matching infinities) always pass. Otherwise an element is wrong when its
 * absolute error AND its relative error (w.r.t. the golden value) both exceed
 * EPS; a NaN on either side is also wrong. Every wrong element is printed.
 *
 * @param outputData  host buffer holding the operator output (outSize bytes).
 * @param outSize     byte size of outputData; also the expected golden data size.
 * @param goldenName  tensor name used to build the golden file path.
 * @return true only if the golden file was read and no element was wrong.
 */
static bool CompareResult(const void *outputData, int64_t outSize, const std::string &goldenName)
{
    void *goldenData = nullptr;
    CHECK_ACL(aclrtMallocHost(&goldenData, outSize));
    if (goldenData == nullptr) {
        printf("test failed!\n");
        return false;
    }

    size_t goldenSize = outSize;
    const bool readOk = ReadFile("../output/golden_" + goldenName + ".bin", goldenSize, goldenData, goldenSize);
    if (!readOk) {
        printf("test failed!\n");
        CHECK_ACL(aclrtFreeHost(goldenData));  // previously leaked on this path
        return false;
    }
    printf("ReadFile golden_%s.bin success!\n", goldenName.c_str());

    // ReadFile may report the number of bytes actually read through goldenSize;
    // a short golden file would otherwise leave part of goldenData uninitialized.
    if (goldenSize != static_cast<size_t>(outSize)) {
        printf("golden_%s.bin holds %zu bytes, expected %lld\n", goldenName.c_str(), goldenSize,
            static_cast<long long>(outSize));
        CHECK_ACL(aclrtFreeHost(goldenData));
        return false;
    }

    const float *actual = static_cast<const float *>(outputData);
    const float *golden = static_cast<const float *>(goldenData);
    const int64_t elemCount = outSize / static_cast<int64_t>(sizeof(float));
    constexpr float EPS = 1e-4f;
    int64_t wrongNum = 0;
    for (int64_t i = 0; i < elemCount; i++) {
        const float a = actual[i];
        const float b = golden[i];
        if (a == b) {
            continue;  // exact match, also covers inf == inf (whose difference would be NaN)
        }
        const float ae = std::fabs(a - b);
        const float re = ae / std::fabs(b);
        // Phrased as "not within tolerance" so a NaN (for which every comparison
        // is false) is reported instead of silently passing.
        if (!(ae <= EPS || re <= EPS)) {
            printf("CompareResult golden_output_%s.bin failed output is %lf, golden is %lf, index is %lld\n",
                goldenName.c_str(), a, b, static_cast<long long>(i));
            wrongNum++;
        }
    }
    CHECK_ACL(aclrtFreeHost(goldenData));

    if (wrongNum != 0) {
        printf("CompareResult golden_output_%s.bin: %lld of %lld elements wrong\n", goldenName.c_str(),
            static_cast<long long>(wrongNum), static_cast<long long>(elemCount));
        return false;
    }
    printf("CompareResult golden_output_%s.bin success\n", goldenName.c_str());
    return true;
}

int main(void) {
    aclrtStream stream;

    int64_t srcGm[] = {A_SIZE * R_SIZE_WITH_PAD};
    int64_t inMeanGm[] = {A_SIZE};
    int64_t inVarGm[] = {A_SIZE};
    int64_t inGammaGm[] = {R_SIZE_WITH_PAD};
    int64_t inBetaGm[] = {R_SIZE_WITH_PAD};
    int64_t outGm[] = {A_SIZE * R_SIZE_WITH_PAD};
    int64_t outRstdGm[] = {A_SIZE};
    struct tensorInfo tensorDesc[] = {
        {srcGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
        {inMeanGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
        {inVarGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
        {inGammaGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
        {inBetaGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
        {outGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
        {outRstdGm, 1, ACL_FLOAT, ACL_FORMAT_ND},
    };

    std::string ParamNames[] = {
        "srcGm",
        "inMeanGm",
        "inVarGm",
        "inGammaGm",
        "inBetaGm",
        "outGm",
        "outRstdGm",
    };
    stream = CreateStream(0);

    aclTensor *tensors[sizeof(tensorDesc) / sizeof(struct tensorInfo)];
    void *devMem[sizeof(tensorDesc) / sizeof(struct tensorInfo)];
    for (auto i = 0; i < sizeof(tensorDesc) / sizeof(struct tensorInfo); i++) {
        void *data;
        struct tensorInfo *info = &(tensorDesc[i]);
        int64_t size = GetDataSize(info);
        if (size == 0) {
            tensors[i] = NULL;
            devMem[i] = NULL;
            continue;
        }
        CHECK_ACL(aclrtMalloc(&data, size, ACL_MEM_MALLOC_HUGE_FIRST));
        // read input
        if (i < SRC_SIZE) {
            size_t inputSize = size;
            void *dataHost;
            CHECK_ACL(aclrtMallocHost((void **)(&dataHost), inputSize));
            ReadFile("../input/input_" + ParamNames[i] + ".bin", inputSize, dataHost, inputSize);
            CHECK_ACL(aclrtMemcpy(data, size, dataHost, size, ACL_MEMCPY_HOST_TO_DEVICE));
            CHECK_ACL(aclrtFreeHost(dataHost));
        }
        devMem[i] = data;
        tensors[i] =
            aclCreateTensor(info->dims, info->dimCnt, info->dtype, NULL, 0, info->fmt, info->dims, info->dimCnt, data);
    }

    size_t workspaceSize = 0;
    aclOpExecutor *handle;
    int32_t ret = 0;
    ret = aclnnNormalizeCustomGetWorkspaceSize(tensors[0], tensors[1], tensors[2],
        tensors[3], tensors[4], tensors[5], tensors[6], &workspaceSize, &handle);
    printf("aclnnNormalizeCustomGetWorkspaceSize ret %u workspace size %lu\n", ret, workspaceSize);
    void *workspace = nullptr;
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    ret = aclnnNormalizeCustom(workspace, workspaceSize, handle, stream);
    printf("aclnnNormalizeCustom ret %u\n", ret);
    if (aclrtSynchronizeStreamWithTimeout(stream, TIMEOUT) != ACL_SUCCESS) {
        printf("Synchronize stream failed\n");
    }

    uint8_t *outHost, *outRstdHost;
    int64_t outHostSize = GetDataSize(&(tensorDesc[5]));
    int64_t outRstdHostSize = GetDataSize(&(tensorDesc[6]));
    CHECK_ACL(aclrtMallocHost((void **)(&outHost), outHostSize));
    CHECK_ACL(aclrtMallocHost((void **)(&outRstdHost), outRstdHostSize));

    CHECK_ACL(aclrtMemcpy(outHost, outHostSize, devMem[5], outHostSize, ACL_MEMCPY_DEVICE_TO_HOST));
    CHECK_ACL(aclrtMemcpy(outRstdHost, outRstdHostSize, devMem[6], outRstdHostSize,
        ACL_MEMCPY_DEVICE_TO_HOST));
    WriteFile("../output/outGm.bin", outHost, outHostSize);
    WriteFile("../output/outRstdGm.bin", outRstdHost, outRstdHostSize);

    bool goldenResult = true;
    goldenResult &= CompareResult(outHost, outHostSize, ParamNames[5]);
    goldenResult &= CompareResult(outRstdHost, outRstdHostSize, ParamNames[6]);
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }

    CHECK_ACL(aclrtFreeHost(outHost));
    CHECK_ACL(aclrtFreeHost(outRstdHost));

    for (auto i = 0; i < sizeof(tensorDesc) / sizeof(struct tensorInfo); i++) {
        if (!tensors[i])
            continue;
        if (devMem[i]) {
            CHECK_ACL(aclrtFree(devMem[i]));
        }
        aclDestroyTensor(tensors[i]);
    }
    DestroyStream(stream, 0);
    return 0;
}
